"""
graph_utils.graph_builder
-------------------------

End-to-end orchestration: builds a weighted, symmetric kNN graph with
self-tuned sigmas and a configurable symmetrization rule.
"""

import numpy as np

from .knn_backends import fast_knn_search
from .sigma_search import self_tuning_sigma
from .weights import compute_weights_vectorized
from .symmetrize import symmetrize_graph
from .kernels import default_kernel, DEFAULT_KERNEL_DEFAULT_PARAMS
from .metrics import resolve as resolve_metric  

def build_weighted_graph(
    X: np.ndarray,
    k: int,
    symmetrize: str = "max",
    use_zero_rho: bool = False,
    backend: str = "auto",
    kernel_function=None,
    kernel_params=None,
    batch_size: int = 6,
    rank_choice: int = 2,
    seed: int = 42,
    *,
    metric: str = "euclidean",              
    metric_params: dict | None = None,        
):

    """
    Build a weighted kNN graph.

    Parameters
    ----------
    X : np.ndarray
        Data matrix (n_samples, n_features)
    k : int
        Number of neighbors for kNN graph
    symmetrize : str
        Symmetrization rule ('max', 'mean', etc)
    use_zero_rho : bool
        If True, force rho = 0
    backend : str
        kNN backend to use ("auto", "faiss", etc)
    kernel_function : callable
        Numba-compiled kernel, or None for default
    kernel_params : np.ndarray
        Parameters for kernel, or None for default
    mid_shortcuts : bool
        Whether to add mid-shortcut edges
    mid_mode : str
        "random" (default) or "2hop"
    mid_scale : float
        Weight scaling for shortcut edges
    mn_ratio : float
        Shortcuts per node = floor(out_degree * mn_ratio)
    batch_size, rank_choice, seed : int
        Advanced options for mid-shortcuts

    Returns
    -------
    rho : np.ndarray
    sigmas : np.ndarray
    ei, ej : np.ndarray
        Final symmetric edge indices
    P_vals : np.ndarray
        Symmetric edge weights
    neigh_idx : np.ndarray
        kNN neighbor indices
    """
    n_samples = X.shape[0]

    # Set kernel defaults
    if kernel_function is None:
        kernel_function = default_kernel
    if kernel_params is None:
        kernel_params = DEFAULT_KERNEL_DEFAULT_PARAMS

        # --- Resolve metric + get backend aliases and code ---
    m = resolve_metric(metric, metric_params or {})
    X_eff, _extras = (m.pre(X, metric_params or {}) if getattr(m, "pre", None) else (X.astype(np.float32, copy=False), {}))

    # --- Pick backend + alias ---
    chosen_backend = backend
    alias = m.backends.get(backend, None)
    if alias is None:
        for cand in ("hnswlib", "pynndescent", "sklearn", "faiss"):
            a = m.backends.get(cand, None)
            if a is not None:
                chosen_backend, alias = cand, a
                break
        if alias is None:
            chosen_backend, alias = "sklearn", "euclidean"


    # 1) kNN search
    distances, indices = fast_knn_search(
        X_eff, int(k), backend=chosen_backend, metric=alias, metric_params=(metric_params or None)
    )
    distances = m.post(distances, chosen_backend)   # standardize to scalar s (Euclidean², 1-cos, etc.)
    distances = distances.astype(np.float32, copy=False)
    indices   = indices.astype(np.int32, copy=False)


    # 2) Rho
    rho = distances[:, 1].astype(np.float32)
    if use_zero_rho:
        rho = np.zeros_like(rho)

    # 3) Sigma
    sigmas = self_tuning_sigma(
        rho, distances[:, 1:], int(k),
        kernel_function=kernel_function,
        kernel_params=kernel_params
    ).astype(np.float32)

    # 4) Compute initial (directed) weights
    neigh_idx = indices[:, 1:].astype(np.int32, copy=False)
    neigh_dist = distances[:, 1:].astype(np.float32, copy=False)

    _SYM_MAP = {"mean": 0, "max": 1, "umap": 2, "geom": 3, "min": 4, "harm": 5, "sinkhorn": 6}
    sym_code = int(_SYM_MAP.get(symmetrize, 0))

    rows, cols, vals = compute_weights_vectorized(
        neigh_idx,
        neigh_dist,
        rho,
        sigmas,
        sym_code,
        n_samples,
        int(k),
        kernel_function,
        kernel_params,
    )

    # 6) Symmetrize
    ei, ej, P_vals = symmetrize_graph(rows, cols, vals, n_samples, symmetrize)

    return rho, sigmas, ei, ej, P_vals, neigh_idx

# Names exported by `from <this module> import *`.
__all__ = ["build_weighted_graph"]